import operator
from functools import reduce
import sys

try:
    import unittest2 as unittest
except ImportError:
    import unittest
import hypothesis.strategies as st
import pytest
from hypothesis import given, settings, example

try:
    from hypothesis import HealthCheck

    HC_PRESENT = True
except ImportError:  # pragma: no cover
    HC_PRESENT = False
from .numbertheory import (
    SquareRootError,
    JacobiError,
    factorization,
    gcd,
    lcm,
    jacobi,
    inverse_mod,
    is_prime,
    next_prime,
    smallprimes,
    square_root_mod_prime,
)

# Prefer the gmpy2 integer type, then the older gmpy one; when neither
# package can be imported, fall back to an identity function so that
# ``mpz(x)`` calls in the tests keep working with plain Python integers.
try:
    from gmpy2 import mpz
except ImportError:
    try:
        from gmpy import mpz
    except ImportError:

        def mpz(x):
            """Return ``x`` unchanged (stand-in when gmpy2/gmpy are absent)."""
            return x


# Ascending six-digit values used by test_next_prime: for every adjacent
# pair (a, b) in this tuple the test asserts that next_prime(a) == b, so
# the entries must stay sorted and no prime between two entries may be
# left out.
BIGPRIMES = (
    999671,
    999683,
    999721,
    999727,
    999749,
    999763,
    999769,
    999773,
    999809,
    999853,
    999863,
    999883,
    999907,
    999917,
    999931,
    999953,
    999959,
    999961,
    999979,
    999983,
)


@pytest.mark.parametrize(
    "prime, next_p", list(zip(BIGPRIMES, BIGPRIMES[1:]))
)
def test_next_prime(prime, next_p):
    """next_prime() of a BIGPRIMES entry must be the entry right after it."""
    successor = next_prime(prime)
    assert successor == next_p


@pytest.mark.parametrize("val", [-1, 0, 1])
def test_next_prime_with_nums_less_2(val):
    """For inputs below 2 next_prime() is expected to answer 2."""
    result = next_prime(val)
    assert result == 2


@pytest.mark.slow
@pytest.mark.parametrize("prime", smallprimes)
def test_square_root_mod_prime_for_small_primes(prime):
    """
    Check every residue class of a small prime.

    Values that are squares modulo ``prime`` must get a root that squares
    back to them; all remaining values must raise SquareRootError.
    """
    residues = set()
    for base in range(prime // 2 + 1):
        residue = base * base % prime
        residues.add(residue)
        found = square_root_mod_prime(residue, prime)
        # tested for real with TestNumbertheory.test_square_root_mod_prime
        assert found * found % prime == residue

    # everything that never showed up as a square has no root
    without_root = (val for val in range(prime) if val not in residues)
    for candidate in without_root:
        with pytest.raises(SquareRootError):
            square_root_mod_prime(candidate, prime)


def test_square_root_mod_prime_for_2():
    """The square root of 1 modulo 2 is expected to be 1."""
    assert square_root_mod_prime(1, 2) == 1


def test_square_root_mod_prime_for_small_prime():
    """98**2 reduces to 9 modulo 101; the root found must square to 9."""
    prime = 101
    square = 98**2 % prime
    found = square_root_mod_prime(square, prime)
    assert found * found % prime == 9


def test_square_root_mod_prime_for_p_congruent_5():
    """Root of 3 modulo 13, a modulus with ``p % 8 == 5``."""
    modulus, square = 13, 3
    # guard the precondition this test case exists for
    assert modulus % 8 == 5

    found = square_root_mod_prime(square, modulus)
    assert found * found % modulus == square


def test_square_root_mod_prime_for_p_congruent_5_large_d():
    """Root of 4 modulo 29, another modulus with ``p % 8 == 5``."""
    modulus, square = 29, 4
    # guard the precondition this test case exists for
    assert modulus % 8 == 5

    found = square_root_mod_prime(square, modulus)
    assert found * found % modulus == square


class TestSquareRootModPrime(unittest.TestCase):
    """Error paths of square_root_mod_prime()."""

    def test_power_of_2_p(self):
        # modulus 32 is expected to be rejected with JacobiError
        self.assertRaises(JacobiError, square_root_mod_prime, 12, 32)

    def test_no_square(self):
        with self.assertRaises(SquareRootError) as caught:
            square_root_mod_prime(12, 31)

        message = str(caught.exception)
        self.assertIn("no square root", message)

    def test_non_prime(self):
        with self.assertRaises(SquareRootError) as caught:
            square_root_mod_prime(12, 33)

        message = str(caught.exception)
        self.assertIn("p is not prime", message)

    def test_non_prime_with_negative(self):
        modulus = 697
        with self.assertRaises(SquareRootError) as caught:
            square_root_mod_prime(modulus - 1, modulus)

        message = str(caught.exception)
        self.assertIn("p is not prime", message)


@st.composite
def st_two_nums_rel_prime(draw):
    """
    Strategy that returns a ``(num, mod)`` pair with ``gcd(num, mod) == 1``
    and ``1 <= num < mod``.
    """
    # 521-bit is the biggest curve we operate on, use 1024 for a bit
    # of breathing space
    modulus = draw(st.integers(min_value=2, max_value=2**1024))

    def coprime_to_modulus(candidate):
        return gcd(candidate, modulus) == 1

    below_modulus = st.integers(min_value=1, max_value=modulus - 1)
    number = draw(below_modulus.filter(coprime_to_modulus))
    return number, modulus


@st.composite
def st_primes(draw, *args, **kwargs):
    """
    Strategy that returns a prime number.

    The value is either an entry of ``smallprimes`` or an integer drawn
    from ``st.integers(*args, **kwargs)`` that passes ``is_prime()``.
    ``min_value`` defaults to 1 when the caller does not provide it.
    """
    kwargs.setdefault("min_value", 1)
    from_table = st.sampled_from(smallprimes)
    from_search = st.integers(*args, **kwargs).filter(is_prime)
    return draw(from_table | from_search)


@st.composite
def st_num_square_prime(draw):
    """
    Strategy that returns ``(square, prime)`` where ``square`` is some
    number squared and reduced modulo ``prime``.
    """
    modulus = draw(st_primes(max_value=2**1024))
    largest_base = modulus // 2 + 1
    base = draw(st.integers(min_value=0, max_value=largest_base))
    return base * base % modulus, modulus


@st.composite
def st_comp_with_com_fac(draw):
    """
    Strategy that returns lists of numbers, all having a common factor.
    """
    prime_pool = draw(
        st.lists(st_primes(max_value=2**512), min_size=1, max_size=10)
    )
    any_pool_prime = st.sampled_from(prime_pool)

    # select random prime(s) that will make the common factor of composites
    shared_primes = draw(st.lists(any_pool_prime, min_size=1, max_size=20))
    shared_factor = reduce(operator.mul, shared_primes, 1)

    def factor_lists(count):
        # at most `count` numbers, each made of at most 30 primes; an empty
        # inner list stands for the number 1 (before the common factor is
        # multiplied in)
        return st.lists(
            st.lists(any_pool_prime, max_size=30),
            min_size=1,
            max_size=count,
        )

    # select at most 20 lists (returned numbers)
    factor_sets = draw(
        st.integers(min_value=1, max_value=20).flatmap(factor_lists)
    )

    # seeding the product with the shared factor multiplies it into
    # every returned number
    return [
        reduce(operator.mul, factors, shared_factor)
        for factors in factor_sets
    ]


@st.composite
def st_comp_no_com_fac(draw):
    """
    Strategy that returns lists of numbers that don't have a common factor.
    """
    prime_pool = draw(
        st.lists(
            st_primes(max_value=2**512), min_size=2, max_size=10, unique=True
        )
    )

    # first select the primes that will create the uncommon factor
    # between returned numbers; at least one pool prime is always left over
    reserved = draw(
        st.lists(
            st.sampled_from(prime_pool),
            min_size=1,
            max_size=len(prime_pool) - 1,
            unique=True,
        )
    )
    lone_factor = reduce(operator.mul, reserved, 1)

    # then build composites from leftover primes
    remaining = [prime for prime in prime_pool if prime not in reserved]

    assert remaining
    assert reserved

    def factor_lists(count):
        # at most `count` numbers, each made of at most 30 leftover primes
        return st.lists(
            st.lists(st.sampled_from(remaining), max_size=30),
            min_size=1,
            max_size=count,
        )

    # select at most 20 lists
    factor_sets = draw(
        st.integers(min_value=1, max_value=20).flatmap(factor_lists)
    )

    result = [reduce(operator.mul, factors, 1) for factors in factor_sets]

    # place the number built from the reserved primes at a random position
    position = draw(st.integers(min_value=0, max_value=len(result)))
    result.insert(position, lone_factor)
    return result


# Keyword arguments for hypothesis ``settings``; passing ``--fast`` on the
# command line lowers the number of generated examples.
HYP_SETTINGS = {}
if HC_PRESENT:  # pragma: no branch
    HYP_SETTINGS.update(
        suppress_health_check=[
            HealthCheck.filter_too_much,
            HealthCheck.too_slow,
        ],
        # the factorization() sometimes takes a long time to finish
        deadline=5000,
    )

if "--fast" in sys.argv:  # pragma: no cover
    HYP_SETTINGS["max_examples"] = 20


# Same options, with a smaller example budget for the expensive tests.
HYP_SLOW_SETTINGS = dict(HYP_SETTINGS)
HYP_SLOW_SETTINGS["max_examples"] = 1 if "--fast" in sys.argv else 20


class TestIsPrime(unittest.TestCase):
    """Spot checks of is_prime() against known primes and composites."""

    def test_very_small_prime(self):
        assert is_prime(23)

    def test_very_small_composite(self):
        assert not is_prime(22)

    def test_small_prime(self):
        assert is_prime(123456791)

    def test_special_composite(self):
        assert not is_prime(10261)

    def test_medium_prime_1(self):
        # nextPrime[2^256]
        assert is_prime(2**256 + 0x129)

    def test_medium_prime_2(self):
        # nextPrime(2^256+0x129)
        assert is_prime(2**256 + 0x12D)

    def test_medium_trivial_composite(self):
        assert not is_prime(2**256 + 0x130)

    def test_medium_non_trivial_composite(self):
        assert not is_prime(2**256 + 0x12F)

    def test_large_prime(self):
        # nextPrime[2^2048]
        assert is_prime(mpz(2) ** 2048 + 0x3D5)

    def test_pseudoprime_base_19(self):
        assert not is_prime(1543267864443420616877677640751301)

    def test_pseudoprime_base_300(self):
        # F. Arnault "Constructing Carmichael Numbers Which Are Strong
        # Pseudoprimes to Several Bases". Journal of Symbolic
        # Computation. 20 (2): 151-161. doi:10.1006/jsco.1995.1042.
        # Section 4.4 Large Example (a pseudoprime to all bases up to
        # 300)
        p = int(
            "29 674 495 668 685 510 550 154 174 642 905 332 730 "
            "771 991 799 853 043 350 995 075 531 276 838 753 171 "
            "770 199 594 238 596 428 121 188 033 664 754 218 345 "
            "562 493 168 782 883".replace(" ", "")
        )

        assert is_prime(p)

        # the composite does not change between attempts, build it once
        composite = p * (313 * (p - 1) + 1) * (353 * (p - 1) + 1)

        # is_prime() gets up to 10 attempts to reject the composite
        # NOTE(review): presumably because its check is randomised - confirm
        for _ in range(10):
            if not is_prime(composite):
                break
        else:
            # self.fail() rather than ``assert False``: a bare assert is
            # removed when Python runs with -O, which would silently turn
            # this failure path into a pass
            self.fail("composite not detected")


# Skips a test when the installed hypothesis has no HealthCheck support.
_requires_health_check = unittest.skipUnless(
    HC_PRESENT,
    "Hypothesis 2.0.0 can't be made tolerant of hard to "
    "meet requirements (like `is_prime()`), the test "
    "case times-out on it",
)


def _recompose(factors):
    """Multiply the ``item[0] ** item[1]`` terms of ``factors`` together."""
    product = 1
    for item in factors:
        product *= item[0] ** item[1]
    return product


class TestNumbertheory(unittest.TestCase):
    """Tests of gcd, lcm, square roots, factorization, jacobi, inverse."""

    def test_gcd(self):
        shared = 3 * 5
        operands = [shared * 7, shared * 11, shared * 13]
        # both the varargs and the single-list calling conventions
        assert gcd(*operands) == shared
        assert gcd(operands) == shared
        assert gcd(3) == 3

    @_requires_health_check
    @settings(**HYP_SLOW_SETTINGS)
    @example([877 * 1151, 877 * 1009])
    @given(st_comp_with_com_fac())
    def test_gcd_with_com_factor(self, numbers):
        divisor = gcd(numbers)
        assert 1 in numbers or divisor != 1
        for value in numbers:
            assert value % divisor == 0

    @_requires_health_check
    @settings(**HYP_SLOW_SETTINGS)
    @example([1151, 1069, 1009])
    @given(st_comp_no_com_fac())
    def test_gcd_with_uncom_factor(self, numbers):
        assert gcd(numbers) == 1

    @settings(**HYP_SLOW_SETTINGS)
    @given(
        st.lists(
            st.integers(min_value=1, max_value=2**8192),
            min_size=1,
            max_size=20,
        )
    )
    def test_gcd_with_random_numbers(self, numbers):
        divisor = gcd(numbers)
        for value in numbers:
            # check that at least it's a divisor
            assert value % divisor == 0

    def test_lcm(self):
        operands = [3, 5 * 3, 7 * 3]
        expected = 3 * 5 * 7
        # both the varargs and the single-list calling conventions
        assert lcm(*operands) == expected
        assert lcm(operands) == expected
        assert lcm(3) == 3

    @settings(**HYP_SLOW_SETTINGS)
    @given(
        st.lists(
            st.integers(min_value=1, max_value=2**8192),
            min_size=1,
            max_size=20,
        )
    )
    def test_lcm_with_random_numbers(self, numbers):
        multiple = lcm(numbers)
        for value in numbers:
            assert multiple % value == 0

    @_requires_health_check
    @settings(**HYP_SLOW_SETTINGS)
    @given(st_num_square_prime())
    def test_square_root_mod_prime(self, vals):
        square, prime = vals

        found = square_root_mod_prime(square, prime)
        assert found * found % prime == square

    @pytest.mark.slow
    @settings(**HYP_SLOW_SETTINGS)
    @given(st.integers(min_value=1, max_value=10**12))
    @example(265399 * 1526929)
    @example(373297**2 * 553991)
    def test_factorization(self, num):
        assert _recompose(factorization(num)) == num

    def test_factorisation_smallprimes(self):
        first, second = 101, 103
        exp = first * second
        assert first in smallprimes
        assert second in smallprimes
        assert _recompose(factorization(exp)) == exp

    def test_factorisation_not_smallprimes(self):
        first, second = 1231, 1237
        exp = first * second
        assert first not in smallprimes
        assert second not in smallprimes
        assert _recompose(factorization(exp)) == exp

    def test_jacobi_with_zero(self):
        assert jacobi(0, 3) == 0

    def test_jacobi_with_one(self):
        assert jacobi(1, 3) == 1

    @settings(**HYP_SLOW_SETTINGS)
    @given(st.integers(min_value=3, max_value=1000).filter(lambda x: x % 2))
    def test_jacobi(self, mod):
        mod = mpz(mod)

        if not is_prime(mod):
            # composite modulus: the symbol must equal the product of the
            # symbols for each factor, raised to that factor's exponent
            factors = factorization(mod)
            for a in range(1, mod):
                expected = 1
                for item in factors:
                    expected *= jacobi(a, item[0]) ** item[1]
                assert expected == jacobi(a, mod)
            return

        # prime modulus: squares give 1, everything else gives -1
        residues = set()
        for base in range(1, mod):
            base = mpz(base)
            assert jacobi(base * base, mod) == 1
            residues.add(base * base % mod)
        for value in range(1, mod):
            if value in residues:
                continue
            assert jacobi(mpz(value), mod) == -1

    @settings(**HYP_SLOW_SETTINGS)
    @given(st_two_nums_rel_prime())
    def test_inverse_mod(self, nums):
        num, mod = nums

        inverse = inverse_mod(num, mod)

        assert 0 < inverse < mod
        assert num * inverse % mod == 1

    def test_inverse_mod_with_zero(self):
        assert inverse_mod(0, 11) == 0
